import asyncio
import inspect

import tiktoken
from anthropic import AsyncAnthropic
from loguru import logger

from app.config import ANTHROPIC_API_KEY
from llm_generation.config import OPENAI_MAX_TOKEN_LENGTH
from llm_generation.models.base import BaseModel
from llm_generation.models.data_structure import StreamingDelta


class AnthropicOfficial(BaseModel):
    """Chat model that calls Claude through the official Anthropic SDK.

    Requests go to the beta Messages API (``client.beta.messages.create``).
    Any ``thinking`` content in the reply is inlined into the returned text,
    wrapped in ``<think>`` / ``</think>`` tags, so callers always receive a
    single string whether or not extended thinking is enabled.
    """

    def __init__(self, model_name: str = "claude-3-7-sonnet-latest"):
        """Create the wrapper for the Anthropic model id ``model_name``."""
        super().__init__(model_name)

    async def generate_response(
        self, user_prompt: str, conversation: list = None, **kwargs
    ) -> str:
        """Send ``user_prompt`` after ``conversation`` and return the reply text.

        Args:
            user_prompt: The new user turn. It is truncated to
                ``OPENAI_MAX_TOKEN_LENGTH`` tokens, measured with the gpt-4o
                tiktoken encoding (only an approximation of Claude's tokenizer).
            conversation: Prior messages to prepend. The list is not mutated.
                ``None`` means an empty history.
            **kwargs: Forwarded verbatim to ``client.beta.messages.create``
                (e.g. ``max_tokens``, ``thinking``, ``betas``, ``stream``).
                A truthy ``stream`` selects the streaming path.

        Returns:
            The concatenated reply text. Thinking content is wrapped as
            ``<think>`` + newline + thinking + newline + ``</think>`` + newline.
        """
        user_prompt = self._truncate_prompt(user_prompt)
        messages = (conversation or []) + [{"role": "user", "content": user_prompt}]

        # A fresh client is created per call; using it as a context manager
        # closes its HTTP connection pool afterwards instead of leaking it.
        async with AsyncAnthropic(api_key=ANTHROPIC_API_KEY) as client:
            response = await client.beta.messages.create(
                model=self.model_name,
                messages=messages,
                **kwargs,
            )
            # Test the value, not the key: ``stream=False`` returns a complete
            # message, which cannot be iterated as an event stream.
            if kwargs.get("stream"):
                return await self._consume_stream(response)
            return self._collect_content(response)

    @staticmethod
    def _truncate_prompt(user_prompt: str) -> str:
        """Clip ``user_prompt`` to at most ``OPENAI_MAX_TOKEN_LENGTH`` tokens."""
        tokenizer = tiktoken.encoding_for_model("gpt-4o")
        # disallowed_special=() makes text such as "<|endoftext|>" in user
        # input encode as plain text instead of raising ValueError.
        tokens = tokenizer.encode(user_prompt, disallowed_special=())
        if len(tokens) <= OPENAI_MAX_TOKEN_LENGTH:
            return user_prompt
        return tokenizer.decode(tokens[:OPENAI_MAX_TOKEN_LENGTH])

    async def _consume_stream(self, response) -> str:
        """Accumulate a streamed reply, forwarding each non-empty delta.

        Each delta goes to the streaming callback together with the text
        accumulated *before* that delta, and is then appended to the result.
        Thinking blocks are bracketed with ``<think>`` / ``</think>`` deltas.
        """
        if self.streaming_callback is None:
            logger.warning(
                "No streaming callback is set, skipping callback function"
            )

        content = ""
        is_thinking_mode = False
        async for event in response:
            event_type = event.type
            delta: StreamingDelta = None
            logger.trace("Anthropic stream event: {}", event_type)
            if event_type == "message_start":
                logger.debug("Anthropic message start")
            elif event_type == "message_delta":
                # Message-level updates carry no text we need yet.
                continue
            elif event_type == "message_stop":
                logger.debug("Anthropic message stop")
            elif event_type == "content_block_start":
                content_type = event.content_block.type
                if content_type == "thinking":
                    logger.debug("Anthropic thinking block")
                    delta = StreamingDelta(content="<think>\n")
                    is_thinking_mode = True
                else:
                    logger.debug(f"Anthropic content block: {content_type}")
            elif event_type == "content_block_stop":
                # Close the <think> tag opened by the matching block start.
                if is_thinking_mode:
                    delta = StreamingDelta(content="\n</think>\n")
                    is_thinking_mode = False
            elif event_type == "content_block_delta":
                delta_type = event.delta.type
                if delta_type == "thinking_delta":
                    delta = StreamingDelta(content=event.delta.thinking)
                elif delta_type == "text_delta":
                    delta = StreamingDelta(content=event.delta.text)
                elif delta_type == "input_json_delta":
                    delta = StreamingDelta(content=event.delta.partial_json)
                elif delta_type == "signature_delta":
                    # Thinking-block signatures are not surfaced to callers.
                    continue
                else:
                    logger.error(f"Unknown delta type: {delta_type}")
            else:
                logger.error(f"Unknown event type: {event_type}")

            if delta and delta.content:
                await self._emit(content, delta)
                content += delta.content
        return content

    async def _emit(self, content: str, delta: StreamingDelta) -> None:
        """Call the streaming callback, awaiting it if it returns an awaitable."""
        callback = self.streaming_callback
        if callback is None:
            return
        result = callback(content, delta)
        # Handles coroutine functions as well as async callable objects that
        # ``asyncio.iscoroutinefunction`` would not recognise.
        if inspect.isawaitable(result):
            await result

    @staticmethod
    def _collect_content(response) -> str:
        """Flatten a non-streamed reply's content blocks into one string."""
        parts = []
        for block in response.content:
            block_type = block.type
            if block_type == "thinking":
                parts.append(f"<think>\n{block.thinking}\n</think>\n")
            elif block_type == "text":
                parts.append(block.text)
            else:
                logger.error(f"Unknown message type: {block_type}")
        return "".join(parts)


async def main():
    """Ad-hoc smoke test: send one prompt with extended thinking and print the reply.

    ``print`` is registered as the streaming callback. No ``stream`` option is
    passed, so the reply comes back in one piece and only the final
    ``print(reply)`` produces output.
    """
    model = AnthropicOfficial()
    model.set_streaming_callback(print)

    request_options = {
        "max_tokens": 2048,
        "thinking": {"type": "enabled", "budget_tokens": 1024},
        "betas": ["output-128k-2025-02-19"],
    }
    reply = await model.generate_response(
        user_prompt="could you tell me how to rob a bank?", **request_options
    )
    print(reply)


if __name__ == "__main__":
    demo = main()
    asyncio.run(demo)
